// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
#pragma once

#include <array>
#include <bit>
#include <cstring>
#include <span>
#include <vector>
#include <boost/container/small_vector.hpp>
#include <boost/container/static_vector.hpp>
#include "common/assert.h"
#include "common/types.h"
#include "shader_recompiler/backend/bindings.h"
#include "shader_recompiler/frontend/copy_shader.h"
#include "shader_recompiler/frontend/tessellation.h"
#include "shader_recompiler/ir/attribute.h"
#include "shader_recompiler/ir/passes/srt.h"
#include "shader_recompiler/ir/reg.h"
#include "shader_recompiler/ir/type.h"
#include "shader_recompiler/params.h"
#include "shader_recompiler/profile.h"
#include "shader_recompiler/runtime_info.h"
#include "video_core/amdgpu/resource.h"

namespace Shader {

// Number of user-data registers; PushData::ud_regs holds exactly this many dwords.
static constexpr size_t NumUserDataRegs{16};
// Inline capacities of the small_vector resource lists (NumBuffers also sizes buf_offsets).
static constexpr size_t NumImages{64};
static constexpr size_t NumBuffers{32};
static constexpr size_t NumSamplers{16};
static constexpr size_t NumFMasks{8};

/// Texture dimensionality/kind of an image binding.
/// `Buffer` must stay the last enumerator: NUM_TEXTURE_TYPES is derived from it.
enum class TextureType : u32 {
    Color1D,
    ColorArray1D,
    Color2D,
    ColorArray2D,
    Color3D,
    ColorCube,
    Buffer,
};
/// Number of TextureType enumerators, computed from the last one so it cannot drift
/// out of sync with the enum when entries are added.
constexpr u32 NUM_TEXTURE_TYPES = static_cast<u32>(TextureType::Buffer) + 1;

/// Backing kind of a BufferResource. Anything other than `Guest` is reported as
/// special by BufferResource::IsSpecial().
enum class BufferType : u32 {
    Guest = 0,
    ReadConstUbo = 1,
    GdsBuffer = 2,
    SharedMemory = 3,
};

// Forward declaration: the resource GetSharp() helpers below take a `const Info&`.
struct Info;

/// A buffer binding used by a shader.
struct BufferResource {
    u32 sharp_idx;              // Dword index of the V# in Info::flattened_ud_buf.
    IR::Type used_types;
    AmdGpu::Buffer inline_cbuf; // Returned by GetSharp() instead of the user-data V# when truthy.
    BufferType buffer_type;
    u8 instance_attrib{0};
    bool is_written{false};
    bool is_formatted{false};

    /// True for every buffer type except BufferType::Guest.
    bool IsSpecial() const noexcept {
        const bool is_guest = buffer_type == BufferType::Guest;
        return !is_guest;
    }

    /// A buffer is bound as storage when it exceeds the profile's UBO size limit
    /// or when the shader writes to it.
    bool IsStorage(const AmdGpu::Buffer& buffer, const Profile& profile) const noexcept {
        if (buffer.GetSize() > profile.max_ubo_size) {
            return true;
        }
        return is_written;
    }

    /// Resolves the buffer descriptor (inline constant buffer or user-data V#).
    [[nodiscard]] constexpr AmdGpu::Buffer GetSharp(const Info& info) const noexcept;
};
using BufferResourceList = boost::container::small_vector<BufferResource, NumBuffers>;

/// An image binding used by a shader.
struct ImageResource {
    u32 sharp_idx; // Dword index of the T# in Info::flattened_ud_buf.
    bool is_depth{false};
    bool is_atomic{false};
    bool is_array{false};
    bool is_written{false};

    /// Resolves the image descriptor, substituting the null image when it is not valid.
    [[nodiscard]] constexpr AmdGpu::Image GetSharp(const Info& info) const noexcept;
};
using ImageResourceList = boost::container::small_vector<ImageResource, NumImages>;

/// A sampler binding: either an inline sampler or a descriptor read from user data.
struct SamplerResource {
    u32 sharp_idx;                     // Dword index of the S# in Info::flattened_ud_buf.
    AmdGpu::Sampler inline_sampler{};  // Preferred by GetSharp() when truthy.
    // NOTE(review): 4 bits can only hold 0..15 while ImageResourceList has inline capacity
    // NumImages (64); confirm the stored value can never exceed 15.
    u32 associated_image : 4;
    u32 disable_aniso : 1;

    /// Resolves the sampler descriptor (inline sampler or user-data S#).
    /// Marked [[nodiscard]] for consistency with the Buffer/Image resource getters.
    [[nodiscard]] constexpr AmdGpu::Sampler GetSharp(const Info& info) const noexcept;
};
using SamplerResourceList = boost::container::small_vector<SamplerResource, NumSamplers>;

/// An FMask binding; its descriptor is read from user data as an image T#.
struct FMaskResource {
    u32 sharp_idx; // Dword index of the descriptor in Info::flattened_ud_buf.

    /// Reads the FMask descriptor. Marked [[nodiscard]] for consistency with the
    /// Buffer/Image resource getters.
    [[nodiscard]] constexpr AmdGpu::Image GetSharp(const Info& info) const noexcept;
};
using FMaskResourceList = boost::container::small_vector<FMaskResource, NumFMasks>;

struct PushData {
    static constexpr u32 Step0Index = 0;
    static constexpr u32 Step1Index = 1;
    static constexpr u32 XOffsetIndex = 2;
    static constexpr u32 YOffsetIndex = 3;
    static constexpr u32 XScaleIndex = 4;
    static constexpr u32 YScaleIndex = 5;
    static constexpr u32 UdRegsIndex = 6;
    static constexpr u32 BufOffsetIndex = UdRegsIndex + NumUserDataRegs / 4;

    u32 step0;
    u32 step1;
    float xoffset;
    float yoffset;
    float xscale;
    float yscale;
    std::array<u32, NumUserDataRegs> ud_regs;
    std::array<u8, NumBuffers> buf_offsets;

    void AddOffset(u32 binding, u32 offset) {
        ASSERT(offset < 256 && binding < buf_offsets.size());
        buf_offsets[binding] = offset;
    }
};
static_assert(sizeof(PushData) <= 128,
              "PushData size is greater than minimum size guaranteed by Vulkan spec");

/**
 * Contains general information generated by the shader recompiler for an input program.
 */
struct Info {
    /// Per-attribute component masks: bit `comp` of `flags[attrib]` marks component `comp`.
    struct AttributeFlags {
        /// Returns whether component `comp` of `attrib` is set.
        bool Get(IR::Attribute attrib, u32 comp = 0) const {
            return flags[Index(attrib)] & (1 << comp);
        }

        /// Returns whether any component of `attrib` is set.
        bool GetAny(IR::Attribute attrib) const {
            return flags[Index(attrib)];
        }

        /// Marks component `comp` of `attrib`.
        void Set(IR::Attribute attrib, u32 comp = 0) {
            flags[Index(attrib)] |= (1 << comp);
        }

        /// Always reports four components; `attrib` is currently ignored.
        u32 NumComponents(IR::Attribute attrib) const {
            return 4;
        }

        static size_t Index(IR::Attribute attrib) {
            return static_cast<size_t>(attrib);
        }

        std::array<u8, IR::NumAttributes> flags;
    };
    AttributeFlags loads{};
    AttributeFlags stores{};

    /// Bitmask of user-data registers recorded via Set(); bit N corresponds to ScalarReg N.
    struct UserDataMask {
        void Set(IR::ScalarReg reg) noexcept {
            // Unsigned shift: a signed `1 << 31` is out of int's range.
            mask |= 1U << static_cast<u32>(reg);
        }

        /// Compacted position of `reg`: the number of recorded registers below it.
        u32 Index(IR::ScalarReg reg) const noexcept {
            // Unsigned arithmetic: with a signed shift, reg == 31 would compute INT_MIN - 1 (UB).
            const u32 reg_mask = (1U << static_cast<u32>(reg)) - 1U;
            return std::popcount(mask & reg_mask);
        }

        /// Number of recorded registers.
        u32 NumRegs() const noexcept {
            return std::popcount(mask);
        }

        u32 mask;
    };
    UserDataMask ud_mask{};

    CopyShaderData gs_copy_data;
    u32 uses_patches{};

    BufferResourceList buffers;
    ImageResourceList images;
    SamplerResourceList samplers;
    FMaskResourceList fmasks;

    PersistentSrtInfo srt_info;
    // Raw user data followed by the SRT leaves written by srt_info.walker_func
    // (filled by RefreshFlatBuf, read by ReadUdSharp).
    std::vector<u32> flattened_ud_buf;

    // Location of the tessellation-constants V#; the offset stays -1 until it has been tracked.
    IR::ScalarReg tess_consts_ptr_base = IR::ScalarReg::Max;
    s32 tess_consts_dword_offset = -1;

    // Non-owning view of the user-data registers passed in via ShaderParams.
    std::span<const u32> user_data;
    Stage stage;
    LogicalStage l_stage;

    u64 pgm_hash{};
    VAddr pgm_base;
    bool has_storage_images{};
    bool has_discard{};
    bool has_image_gather{};
    bool has_image_query{};
    bool uses_lane_id{};
    bool uses_group_quad{};
    bool uses_group_ballot{};
    bool uses_shared{};
    bool uses_fp16{};
    bool uses_fp64{};
    bool uses_pack_10_11_11{};
    bool uses_unpack_10_11_11{};
    bool stores_tess_level_outer{};
    bool stores_tess_level_inner{};
    bool translation_failed{};
    bool has_readconst{};
    u8 mrt_mask{0u};
    bool has_fetch_shader{false};
    u32 fetch_shader_sgpr_base{0u};

    // Initializers are listed in declaration order (user_data is declared before stage);
    // the previous order did not match and triggered -Wreorder.
    explicit Info(Stage stage_, LogicalStage l_stage_, ShaderParams params)
        : user_data{params.user_data}, stage{stage_}, l_stage{l_stage_}, pgm_hash{params.hash},
          pgm_base{params.Base()} {}

    /// Reads a descriptor of type T starting at dword `sharp_idx` of flattened_ud_buf.
    template <typename T>
    T ReadUdSharp(u32 sharp_idx) const noexcept {
        // Copy the bytes out rather than dereferencing a reinterpret_cast'd pointer: the
        // buffer's objects are u32, so accessing them as T breaks strict aliasing, and a u32
        // slot does not guarantee T's alignment. Matches the approach used by ReadUdReg.
        T data;
        std::memcpy(&data, flattened_ud_buf.data() + sharp_idx, sizeof(T));
        return data;
    }

    /// Reads a T located `dword_offset` dwords from a base address.
    /// If `ptr_index` equals IR::NumScalarRegs the base is user_data itself; otherwise
    /// user_data[ptr_index] holds a pointer (copied out with memcpy) whose value is masked
    /// to its low 48 bits before use.
    /// NOTE(review): ReadTessConstantBuffer passes tess_consts_ptr_base, which defaults to
    /// IR::ScalarReg::Max; confirm that value matches the IR::NumScalarRegs sentinel here.
    template <typename T>
    T ReadUdReg(u32 ptr_index, u32 dword_offset) const noexcept {
        T data;
        const u32* base = user_data.data();
        if (ptr_index != IR::NumScalarRegs) {
            std::memcpy(&base, &user_data[ptr_index], sizeof(base));
            base = reinterpret_cast<const u32*>(VAddr(base) & 0xFFFFFFFFFFFFULL);
        }
        std::memcpy(&data, base + dword_offset, sizeof(T));
        return data;
    }

    /// Appends every user-data register recorded in ud_mask (ascending register order) to
    /// push.ud_regs, starting at slot bnd.user_data and advancing that counter.
    void PushUd(Backend::Bindings& bnd, PushData& push) const {
        u32 mask = ud_mask.mask;
        while (mask) {
            const u32 index = std::countr_zero(mask);
            ASSERT(bnd.user_data < NumUserDataRegs && index < NumUserDataRegs);
            mask &= ~(1U << index);
            push.ud_regs[bnd.user_data++] = user_data[index];
        }
    }

    /// Advances the binding counters by the resources this shader declares.
    /// Note: fmasks are not counted toward `unified` here.
    void AddBindings(Backend::Bindings& bnd) const {
        bnd.buffer += buffers.size();
        bnd.unified += buffers.size() + images.size() + samplers.size();
        bnd.user_data += ud_mask.NumRegs();
    }

    /// Rebuilds flattened_ud_buf: sizes it to srt_info.flattened_bufsize_dw dwords, copies the
    /// raw user data to its front, then runs the SRT walker (if any) to fill in the leaves.
    void RefreshFlatBuf() {
        flattened_ud_buf.resize(srt_info.flattened_bufsize_dw);
        ASSERT(user_data.size() <= NumUserDataRegs);
        // The raw user data is copied to the front of the flat buffer; it must fit there.
        ASSERT(user_data.size() <= flattened_ud_buf.size());
        std::memcpy(flattened_ud_buf.data(), user_data.data(), user_data.size_bytes());
        // Run the JIT program to walk the SRT and write the leaves to a flat buffer
        if (srt_info.walker_func) {
            srt_info.walker_func(user_data.data(), flattened_ud_buf.data());
        }
    }

    /// Copies the tessellation constant buffer referenced by the tracked V# into
    /// `tess_constants`. The V#'s base_address is dereferenced directly as a host pointer.
    void ReadTessConstantBuffer(TessellationDataConstantBuffer& tess_constants) const {
        ASSERT(tess_consts_dword_offset >= 0); // We've already tracked the V# UD
        const auto buf = ReadUdReg<AmdGpu::Buffer>(static_cast<u32>(tess_consts_ptr_base),
                                                   static_cast<u32>(tess_consts_dword_offset));
        const VAddr tess_constants_addr = buf.base_address;
        std::memcpy(&tess_constants,
                    reinterpret_cast<const TessellationDataConstantBuffer*>(tess_constants_addr),
                    sizeof(tess_constants));
    }
};

constexpr AmdGpu::Buffer BufferResource::GetSharp(const Info& info) const noexcept {
    // An inline constant buffer, when present, takes precedence over the user-data V#.
    if (inline_cbuf) {
        return inline_cbuf;
    }
    return info.ReadUdSharp<AmdGpu::Buffer>(sharp_idx);
}

constexpr AmdGpu::Image ImageResource::GetSharp(const Info& info) const noexcept {
    const auto descriptor = info.ReadUdSharp<AmdGpu::Image>(sharp_idx);
    if (descriptor.Valid()) {
        return descriptor;
    }
    // Unbound descriptor: substitute the null image.
    return AmdGpu::Image::Null();
}

constexpr AmdGpu::Sampler SamplerResource::GetSharp(const Info& info) const noexcept {
    // An inline sampler, when present, takes precedence over the user-data S#.
    if (inline_sampler) {
        return inline_sampler;
    }
    return info.ReadUdSharp<AmdGpu::Sampler>(sharp_idx);
}

constexpr AmdGpu::Image FMaskResource::GetSharp(const Info& info) const noexcept {
    // Returned as read; unlike ImageResource::GetSharp there is no null-image fallback.
    auto fmask_descriptor = info.ReadUdSharp<AmdGpu::Image>(sharp_idx);
    return fmask_descriptor;
}

} // namespace Shader
